import heapq
import copy
import random

class Search(object):
    """Search routines over an abstraction's state graph.

    All methods are static; nothing is stored on the class between calls.
    """

    @staticmethod
    def _plan_from_path(options, path):
        """Return ``[options[a][b], ...]`` for each consecutive pair in ``path``.

        A path with fewer than two entries yields an empty list.
        """
        return [options[src][dst] for src, dst in zip(path, path[1:])]

    @staticmethod
    def astar_search(init, goal, s1, s2, heuristic, options, abstraction, reachability):
        """Best-first search from ``init`` to ``goal``.

        Args:
            init: start state; must expose ``.id``.
            goal: target state; an expanded state equal (``==``) to it ends
                the search.
            s1, s2: states whose ids delimit the returned segment when path
                entries are not plain ints (the "interface" case); only read
                in that case.
            heuristic: callable ``heuristic(state_a, state_b)`` returning a
                number.
            options: ``options[src_id][dst_id]`` is an object with ``.cost``.
            abstraction: provides ``get_state_from_id(state_id)``.
            reachability: ``reachability[state_id]`` iterates successor ids.

        Returns:
            ``([plan], True)`` where ``plan`` lists the options along the
            found path, or ``(path_ids, False)`` when the goal was not reached
            or the (possibly trimmed) path contains no transition.
        """
        # Heap entries: (priority, tie, node, path_ids, visited_ids, cost).
        # ``tie`` is unique and increasing, so heapq never falls through to
        # comparing node objects (which need not be orderable); equal
        # priorities are served first-in, first-out.
        fringe = [(0, 0, init, [], set(), 0)]
        tie = 1
        path = []
        while fringe:
            _, _, cur_node, cur_path, cur_visited, cost = heapq.heappop(fringe)
            # Every entry owns private path/visited copies, so mutating them
            # here cannot affect other entries still on the heap.
            cur_visited.add(cur_node.id)
            cur_path.append(cur_node.id)
            if cur_node == goal:
                path = cur_path
                break

            for state_id in reachability[cur_node.id]:
                if state_id in cur_visited:
                    continue
                next_state = abstraction.get_state_from_id(state_id)
                # NOTE(review): the priority is the cost so far plus the
                # heuristic between the current state and the successor, not
                # successor-to-goal as in textbook A*. Kept unchanged; confirm
                # this is intended.
                priority = cost + heuristic(cur_node, next_state)
                next_cost = cost + options[cur_node.id][state_id].cost
                heapq.heappush(
                    fringe,
                    (priority, tie, next_state, cur_path[:], set(cur_visited), next_cost)
                )
                tie += 1

        if not path:
            # The fringe was exhausted without reaching ``goal``.
            return path, False

        if not isinstance(path[0], int):
            # Interface search: each path entry is a collection of ids. Trim
            # the path to end at the first entry containing s2.id (or the last
            # entry if none does), starting at the latest entry up to that
            # point containing s1.id (or the first entry if none does).
            if len(path) < 2:
                return path, False
            start, end = 0, len(path) - 1
            for idx, ids in enumerate(path):
                members = list(ids)
                if s1.id in members:
                    start = idx
                if s2.id in members:
                    end = idx
                    break
            path = path[start:end + 1]

        solution = Search._plan_from_path(options, path)
        if not solution:
            # No transition to return (e.g. init == goal, or the interface
            # segment collapsed to a single entry).
            return path, False
        return [solution], True

    @staticmethod
    def beam_search(init, goal, heuristic, options, abstraction, k=50, n=20, nsamples=5):
        """Beam search for high-level plans connecting ``init`` and ``goal``.

        Partial paths are seeded from ``init``, ``goal`` and up to ``nsamples``
        randomly chosen other states. A path grows until it has touched both
        endpoints. Reaching the goal first makes it grow at its front; reaching
        init first reverses it so it starts at init.

        Args:
            init, goal: endpoint states; recognised by comparing ``.id``.
            heuristic: callable ``heuristic(state_a, state_b)`` returning a
                number.
            options: ``options[state]`` maps destination states to options
                exposing ``.dest`` and ``.cost`` (keyed by state objects).
            abstraction: exposes ``abstract_states``, a dict whose values are
                the candidate seed states.
            k: beam width, i.e. the number of best partial paths kept per
                iteration.
            n: the search stops once ``n`` distinct plans were collected. The
                check runs between iterations, so the last iteration may add
                more.
            nsamples: number of extra random seed states, capped by how many
                states other than ``init``/``goal`` exist.

        Returns:
            A list of plans, each a list of options along one found path.
        """
        states = list(abstraction.abstract_states.values())
        for endpoint in (init, goal):
            # Guarded so init == goal (or an endpoint absent from the
            # abstraction) does not raise.
            if endpoint in states:
                states.remove(endpoint)
        sample_size = max(0, min(nsamples, len(states)))
        seeds = [init, goal] + random.sample(states, sample_size)

        # Entries: (priority, tie, [state, path, visited, cost, reached_goal,
        # reached_init]). Seeds are ranked by seed order; the unique ``tie``
        # keeps heapq from comparing state objects on equal priorities.
        fringe = [(rank, rank, [state, [], set(), 0, False, False])
                  for rank, state in enumerate(seeds)]
        tie = len(fringe)
        solutions = []
        hl_plans = []

        while fringe and n > 0:
            # Keep only the k best partial paths; the rest are dropped.
            beam = heapq.nsmallest(k, fringe)
            fringe = []
            for _, _, entry in beam:
                current, path, visited, cost, reached_goal, reached_init = entry
                plan_found = False
                # True when the endpoint handling below already placed the
                # state in ``path``.
                already_in_path = False

                if current.id == goal.id and not reached_goal:
                    path.append(current)
                    already_in_path = True
                    reached_goal = True
                    if reached_init:
                        plan_found = True
                    else:
                        # Goal reached first: resume expanding from the path's
                        # first state; later states are inserted at the front.
                        visited.add(current)
                        current = path[0]
                        visited.discard(current)
                elif current.id == init.id and not reached_init:
                    reached_init = True
                    already_in_path = True
                    if reached_goal:
                        path.insert(0, current)
                        plan_found = True
                    else:
                        # Init reached first: reverse so the path starts at
                        # init, then resume expanding from its other end.
                        visited.add(current)
                        path.append(current)
                        current = path[0]
                        path = path[::-1]
                        visited.discard(current)

                if plan_found:
                    if path not in solutions:
                        solutions.append(path)
                        # NOTE(review): paths grown at the front were expanded
                        # via options[current] (outgoing), yet the plan reads
                        # options[path[i]][path[i + 1]]; this assumes options
                        # exist in both directions. Confirm.
                        hl_plans.append(Search._plan_from_path(options, path))
                        n -= 1
                    continue

                if current in visited:
                    continue
                visited.add(current)
                if not already_in_path:
                    if reached_goal:
                        path.insert(0, current)
                    else:
                        path.append(current)

                # Estimate toward whichever endpoint is still missing. It is
                # evaluated at the expanded state, so every child of
                # ``current`` shares it.
                if reached_init:
                    remaining = heuristic(current, goal)
                elif reached_goal:
                    remaining = heuristic(current, init)
                else:
                    remaining = min(heuristic(current, init), heuristic(current, goal))

                for option in options[current].values():
                    if option.dest == current:
                        continue
                    # Each child pays only its own edge, not its siblings'.
                    child_cost = cost + option.cost
                    fringe.append((
                        child_cost + remaining,
                        tie,
                        [option.dest, path[:], set(visited), child_cost,
                         reached_goal, reached_init]
                    ))
                    tie += 1
        return hl_plans